from mindspore import nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore import context
import os
import numpy as np
from cells import SigmoidCrossEntropyWithLogits, GenWithLossCell, DisWithLossCell, TrainOneStepCell
import matplotlib.pyplot as plt


class DataGenerator:
    """Source of Gaussian "real" samples for GAN training."""

    def __init__(self, mu, sigma):
        # Parameters of the target normal distribution.
        self.mu = mu
        self.sigma = sigma

    def sample(self, N):
        """Draw samples from N(mu, sigma); N may be an int or a shape tuple."""
        return np.random.normal(loc=self.mu, scale=self.sigma, size=N)


class Discriminator(nn.Cell):
    """MLP critic: maps a scalar sample to a single real/fake logit."""

    def __init__(self, hidden_dim, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        layers = [
            nn.Dense(1, hidden_dim),    # scalar input -> hidden features
            nn.LeakyReLU(),
            nn.Dense(hidden_dim, 1),    # hidden features -> one logit
        ]
        self.model = nn.SequentialCell(layers)

    def construct(self, x):
        """Return the discriminator logit for a batch of inputs x."""
        logit = self.model(x)
        return logit


class Generator(nn.Cell):
    """MLP generator: maps a latent code vector to one scalar sample."""

    def __init__(self, input_dim, hidden_dim, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        layers = [
            nn.Dense(input_dim, hidden_dim),    # latent code -> hidden features
            nn.ReLU(),
            nn.Dense(hidden_dim, 1),            # hidden features -> scalar output
        ]
        self.model = nn.SequentialCell(layers)

    def construct(self, x):
        """Return generated samples for a batch of latent codes x."""
        sample = self.model(x)
        return sample


def normfun(x, mu, sigma):
    """Evaluate the normal pdf N(mu, sigma) at x (scalar or ndarray)."""
    coeff = 1.0 / (sigma * np.sqrt(2 * np.pi))
    exponent = -np.square(x - mu) / (2 * sigma ** 2)
    return coeff * np.exp(exponent)


def draw(gen_data, step):
    """Plot the target N(0, 1) pdf (red) against a Gaussian fitted to the
    generator's output (blue) and save the figure to ./images/step{step}.png.

    Args:
        gen_data: numpy array of samples produced by the generator.
        step: training step number, used in the output filename.
    """
    # savefig raises if the target directory is missing; create it up front.
    os.makedirs("./images", exist_ok=True)

    x = np.arange(-5, 5, 0.2)
    # Reference curve: the real data distribution N(0, 1).
    y = normfun(x, 0, 1)
    plt.plot(x, y, 'r', linewidth=3)

    # Gaussian fitted to the generated samples (moment matching).
    mean = gen_data.mean()
    std = gen_data.std()
    y = normfun(x, mean, std)
    plt.plot(x, y, 'b', linewidth=3)

    plt.xlabel('value')
    plt.ylabel('Probability')
    plt.savefig("./images/step{}.png".format(step))
    plt.clf()  # reset the figure so successive calls do not overlay curves


# ---------------------------------------------------------------------------
# Training script: fit a 1-D GAN so the generator maps uniform noise to
# samples approximating N(0, 1).
# ---------------------------------------------------------------------------
np.random.seed(58)  # fixed seed for reproducible data/noise draws
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
batch_size = 256    # samples per training step
input_dim = 1       # latent-code dimensionality fed to the generator
hidden_dim = 10     # hidden-layer width for both networks
steps = 20000       # total adversarial training iterations
lr = 0.0001         # learning rate shared by both Adam optimizers

netG = Generator(input_dim, hidden_dim)
netD = Discriminator(hidden_dim)
loss = SigmoidCrossEntropyWithLogits()
# Wrap each network with its adversarial loss (defined in the project's
# `cells` module); each wrapper sees both networks.
netG_with_loss = GenWithLossCell(netG, netD, loss)
netD_with_loss = DisWithLossCell(netG, netD, loss)
# beta1=0.5 is the customary Adam momentum setting for GAN training.
optimizerG = nn.Adam(netG.trainable_params(), lr, beta1=0.5)
optimizerD = nn.Adam(netD.trainable_params(), lr, beta1=0.5)

# Single cell that updates the discriminator and the generator in one call.
net_train = TrainOneStepCell(netG_with_loss, netD_with_loss, optimizerG,
                             optimizerD)
netG.set_train()
netD.set_train()

dataset = DataGenerator(0, 1)  # "real" data: standard normal samples
for step in range(steps):
    # One batch of real samples, shape (batch_size, 1).
    real_data = dataset.sample((batch_size, 1))
    real_data = Tensor(real_data, dtype=mstype.float32)
    # Fresh uniform noise in [0, 1) as generator input.
    latent_code = Tensor(np.random.uniform(low=0.0,
                                           high=1.0,
                                           size=(batch_size, input_dim)),
                         dtype=mstype.float32)
    # dout/gout: discriminator and generator losses for this step
    # (presumably scalar Tensors — returned by the project's TrainOneStepCell).
    dout, gout = net_train(real_data, latent_code)
    if (step + 1) % 500 == 0:
        # Periodic progress report and a snapshot plot of the learned density.
        print("step {}, d_loss is {:.4f}, g_loss is {:.4f}".format(
            step + 1, dout.asnumpy(), gout.asnumpy()))
        latent_code = Tensor(np.random.uniform(low=0.0,
                                               high=1.0,
                                               size=(batch_size, input_dim)),
                             dtype=mstype.float32)
        gen_data = netG(latent_code).asnumpy()
        draw(gen_data, step + 1)
